{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For Colab only: ask the runtime for TensorFlow 2.x before it is imported.\n",
    "\n",
    "try:\n",
    "  # %tensorflow_version only exists in Colab.\n",
    "  %tensorflow_version 2.x\n",
    "except Exception:\n",
    "  # Best effort: if the magic is unavailable, continue with the installed version.\n",
    "  pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.0.0\n",
      "False\n"
     ]
    }
   ],
   "source": [
    "# Show the installed TensorFlow version and whether a GPU is usable.\n",
    "# NOTE(review): tf.test.is_gpu_available is deprecated in newer TF 2.x\n",
    "# releases - confirm before upgrading past the 2.0.0 shown in the output.\n",
    "print(tf.__version__)\n",
    "print(tf.test.is_gpu_available())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.4.0\n",
      "False\n"
     ]
    }
   ],
   "source": [
    "# Show the installed PyTorch version and whether CUDA is usable.\n",
    "print(torch.__version__)\n",
    "print(torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Scalar, Vector and Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "()\n",
      "0\n",
      "tf.Tensor(2.4, shape=(), dtype=float32)\n",
      "tf.Tensor([2.4], shape=(1,), dtype=float32)\n",
      "dim =  1\n",
      "tf.Tensor(\n",
      "[[1 2]\n",
      " [3 4]], shape=(2, 2), dtype=int32)\n",
      "dim =  2\n",
      "tf.Tensor(\n",
      "[[[1 2]\n",
      "  [3 4]]], shape=(1, 2, 2), dtype=int32)\n",
      "dim =  3\n",
      "tf.Tensor(\n",
      "[[[1]\n",
      "  [2]]\n",
      "\n",
      " [[3]\n",
      "  [4]]], shape=(2, 2, 1), dtype=int32)\n",
      "dim =  3\n"
     ]
    }
   ],
   "source": [
    "# dim = 0: a scalar has an empty shape, so len(shape) is 0\n",
    "a = tf.constant(2.4)\n",
    "print(a.shape)\n",
    "print(len(a.shape))\n",
    "print(a)\n",
    "\n",
    "# dim = 1: wrapping the value in a list gives a vector of length 1\n",
    "a = tf.constant([2.4])\n",
    "print(a)\n",
    "print(\"dim = \",len(a.shape))\n",
    "\n",
    "# dim = 2: a 2x2 matrix\n",
    "b = tf.constant([[1,2],[3,4]])\n",
    "print(b)\n",
    "print(\"dim = \",len(b.shape))\n",
    "\n",
    "# dim = 3: one more bracket level adds a leading axis -> shape (1, 2, 2)\n",
    "c = tf.constant([[[1,2],[3,4]]])\n",
    "print(c)\n",
    "print(\"dim = \",len(c.shape))\n",
    "\n",
    "# dim = 3: the same four values nested differently -> shape (2, 2, 1)\n",
    "d = tf.constant([[[1],[2]],[[3],[4]]])\n",
    "print(d)\n",
    "print(\"dim = \",len(d.shape))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([])\n",
      "0\n",
      "tensor(2.4000)\n",
      "tensor([2.4000])\n",
      "dim =  1\n",
      "tensor([[1, 2],\n",
      "        [3, 4]])\n",
      "dim =  2\n",
      "tensor([[[1, 2],\n",
      "         [3, 4]]])\n",
      "dim =  3\n",
      "tensor([[[1],\n",
      "         [2]],\n",
      "\n",
      "        [[3],\n",
      "         [4]]])\n",
      "dim =  3\n"
     ]
    }
   ],
   "source": [
    "# dim = 0: a scalar has an empty torch.Size, so len(shape) is 0\n",
    "a = torch.tensor(2.4)\n",
    "print(a.shape)\n",
    "print(len(a.shape))\n",
    "print(a)\n",
    "\n",
    "# dim = 1: wrapping the value in a list gives a vector of length 1\n",
    "a = torch.tensor([2.4])\n",
    "print(a)\n",
    "print(\"dim = \",len(a.shape))\n",
    "\n",
    "# dim = 2: a 2x2 matrix\n",
    "b = torch.tensor([[1,2],[3,4]])\n",
    "print(b)\n",
    "print(\"dim = \",len(b.shape))\n",
    "\n",
    "# dim = 3: one more bracket level adds a leading axis -> shape [1, 2, 2]\n",
    "c = torch.tensor([[[1,2],[3,4]]])\n",
    "print(c)\n",
    "print(\"dim = \",len(c.shape))\n",
    "\n",
    "# dim = 3: the same four values nested differently -> shape [2, 2, 1]\n",
    "d = torch.tensor([[[1],[2]],[[3],[4]]])\n",
    "print(d)\n",
    "print(\"dim = \",len(d.shape))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* A tf.Variable represents a tensor whose value can be changed by running ops on it. Specific ops allow you to read and modify the values of this tensor.\n",
    "\n",
    "* `torch.Tensor(*sizes)` can allocate a tensor of a given shape, but its values are uninitialized, so this is not a recommended initialization method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "cell_style": "split",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "()\n",
      "0\n",
      "<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.4>\n",
      "<tf.Variable 'Variable:0' shape=(1,) dtype=float32, numpy=array([2.4], dtype=float32)>\n",
      "dim =  1\n",
      "<tf.Variable 'Variable:0' shape=(2, 2) dtype=int32, numpy=\n",
      "array([[1, 2],\n",
      "       [3, 4]], dtype=int32)>\n",
      "dim =  2\n",
      "<tf.Variable 'Variable:0' shape=(1, 2, 2) dtype=int32, numpy=\n",
      "array([[[1, 2],\n",
      "        [3, 4]]], dtype=int32)>\n",
      "dim =  3\n",
      "<tf.Variable 'Variable:0' shape=(2, 2, 1) dtype=int32, numpy=\n",
      "array([[[1],\n",
      "        [2]],\n",
      "\n",
      "       [[3],\n",
      "        [4]]], dtype=int32)>\n",
      "dim =  3\n"
     ]
    }
   ],
   "source": [
    "# tf.Variable takes the same nested-list data as tf.constant.\n",
    "# dim = 0: a scalar has an empty shape, so len(shape) is 0\n",
    "a = tf.Variable(2.4)\n",
    "print(a.shape)\n",
    "print(len(a.shape))\n",
    "print(a)\n",
    "\n",
    "# dim = 1: wrapping the value in a list gives a vector of length 1\n",
    "a = tf.Variable([2.4])\n",
    "print(a)\n",
    "print(\"dim = \",len(a.shape))\n",
    "\n",
    "# dim = 2: a 2x2 matrix\n",
    "b = tf.Variable([[1,2],[3,4]])\n",
    "print(b)\n",
    "print(\"dim = \",len(b.shape))\n",
    "\n",
    "# dim = 3: one more bracket level adds a leading axis -> shape (1, 2, 2)\n",
    "c = tf.Variable([[[1,2],[3,4]]])\n",
    "print(c)\n",
    "print(\"dim = \",len(c.shape))\n",
    "\n",
    "# dim = 3: the same four values nested differently -> shape (2, 2, 1)\n",
    "d = tf.Variable([[[1],[2]],[[3],[4]]])\n",
    "print(d)\n",
    "print(\"dim = \",len(d.shape))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([4.9886e-15])\n",
      "dim =  1\n",
      "tensor([[         6,          0, 1167094992],\n",
      "        [         1,          0,          0]], dtype=torch.int32)\n",
      "dim =  2\n",
      "tensor([[[2.9427e-44, 0.0000e+00, 9.4344e+35, 4.5898e-41],\n",
      "         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
      "         [0.0000e+00, 0.0000e+00, 1.8101e+33, 4.5898e-41]],\n",
      "\n",
      "        [[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
      "         [1.8101e+33, 4.5898e-41, 0.0000e+00, 0.0000e+00],\n",
      "         [0.0000e+00, 0.0000e+00,        nan,        nan]]])\n",
      "torch.Size([2, 3, 4])\n",
      "dim =  3\n",
      "tensor([[[1.6816e-44, 6.8247e-06, 0.0000e+00, 5.9062e+04],\n",
      "         [0.0000e+00, 1.7165e-05, 0.0000e+00, 2.7432e+04],\n",
      "         [0.0000e+00, 1.7165e-05, 0.0000e+00, 2.7432e+04]],\n",
      "\n",
      "        [[0.0000e+00, 1.7165e-05, 1.4013e-45, 4.5918e-40],\n",
      "         [4.6200e+03, 1.4013e-45, 0.0000e+00, 0.0000e+00],\n",
      "         [0.0000e+00, 0.0000e+00, 0.0000e+00, 2.0000e+00]]])\n",
      "torch.Size([2, 3, 4])\n",
      "dim =  3\n"
     ]
    }
   ],
   "source": [
    "# The size-based constructors below allocate memory without initializing it,\n",
    "# so the printed values are arbitrary and differ from run to run.\n",
    "\n",
    "# dim = 1: torch.Tensor(1) creates ONE uninitialized float element (shape [1]),\n",
    "# not a tensor holding the value 1. The assignment must be live so that the\n",
    "# cell also runs on a fresh kernel (Restart & Run All).\n",
    "x = torch.Tensor(1)\n",
    "\n",
    "print(x)\n",
    "print('dim = ', len(x.shape))\n",
    "\n",
    "\n",
    "# dim = 2: typed legacy constructor, 2x3 int32 tensor\n",
    "x = torch.IntTensor(2,3)\n",
    "print(x)\n",
    "print('dim = ', len(x.shape))\n",
    "\n",
    "# dim = 3: typed legacy constructor, 2x3x4 float32 tensor\n",
    "x = torch.FloatTensor(2,3,4)\n",
    "print(x)\n",
    "print(x.shape)\n",
    "print('dim = ', len(x.shape))\n",
    "\n",
    "# torch.empty is the explicit way to request the same uninitialized storage\n",
    "x = torch.empty(2,3,4)\n",
    "print(x)\n",
    "print(x.shape)\n",
    "print('dim = ', len(x.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tensor Initialization\n",
    "* From numpy or list\n",
    "* All one/zero/eye\n",
    "* Random normal/uniform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<tf.Variable 'Variable:0' shape=(3,) dtype=int64, numpy=array([1, 2, 3])>\n",
      "[1 2 3]\n",
      "[1, 2, 3]\n",
      "tf.Tensor(\n",
      "[[1. 1. 1.]\n",
      " [1. 1. 1.]\n",
      " [1. 1. 1.]], shape=(3, 3), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[[0. 0. 0.]\n",
      "  [0. 0. 0.]]], shape=(1, 2, 3), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[1. 0. 0. 0.]\n",
      " [0. 1. 0. 0.]\n",
      " [0. 0. 1. 0.]\n",
      " [0. 0. 0. 1.]], shape=(4, 4), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[-0.92840195 -0.21091819 -0.2701943 ]\n",
      " [ 0.83940816  0.4870708   0.313586  ]\n",
      " [ 0.8899615   0.62667656 -0.33387423]], shape=(3, 3), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[-0.00574434 -1.873318   -0.187527  ]\n",
      " [ 0.36099178 -0.87048984  0.36814544]\n",
      " [ 0.47805578 -0.08826071 -1.377957  ]], shape=(3, 3), dtype=float32)\n",
      "tf.Tensor(\n",
      "[[3 3 3]\n",
      " [3 3 3]\n",
      " [3 3 3]], shape=(3, 3), dtype=int32)\n",
      "(3, 3)\n"
     ]
    }
   ],
   "source": [
    "# From NumPy: the variable keeps the array dtype (int64 in the recorded output)\n",
    "a = np.array([1,2,3])\n",
    "aa = tf.Variable(a)\n",
    "print(aa)\n",
    "\n",
    "# back to a NumPy array and a Python list\n",
    "print(aa.numpy())\n",
    "print(aa.numpy().tolist())\n",
    "\n",
    "# All ones\n",
    "b = tf.ones(shape=[3,3])\n",
    "print(b)\n",
    "\n",
    "# All zeros\n",
    "c = tf.zeros(shape=[1,2,3])\n",
    "print(c)\n",
    "\n",
    "# 4x4 identity matrix\n",
    "d = tf.eye(4)\n",
    "print(d)\n",
    "\n",
    "# Uniform random samples between minval and maxval\n",
    "e = tf.random.uniform([3,3],\n",
    "                     minval=-1,\n",
    "                     maxval= 1)\n",
    "print(e)\n",
    "\n",
    "# Normally distributed random samples (default parameters)\n",
    "f = tf.random.normal([3,3])\n",
    "print(f)\n",
    "\n",
    "# Constant fill: every element is 3\n",
    "g = tf.fill(dims=[3,3], value=3)\n",
    "print(g)\n",
    "print(g.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.2205,  1.1558, -1.1998],\n",
      "        [ 0.5872, -0.1655, -0.8753],\n",
      "        [-0.6385,  0.3930, -0.9416]])\n"
     ]
    }
   ],
   "source": [
    "# From NumPy: both calls accept an ndarray; the second assignment replaces\n",
    "# the first, so only the torch.tensor result is printed below.\n",
    "# NOTE(review): per the PyTorch docs from_numpy shares memory with the array\n",
    "# while torch.tensor copies it - consider showing the two separately.\n",
    "a = np.array([1,2,3])\n",
    "aa = torch.from_numpy(a)\n",
    "aa = torch.tensor(a)\n",
    "print(aa)\n",
    "\n",
    "# back to numpy, list, item\n",
    "print(aa.numpy())\n",
    "print(aa.tolist())\n",
    "print(aa[0].item())\n",
    "\n",
    "# All ones\n",
    "b = torch.ones([3,3])\n",
    "print(b)\n",
    "\n",
    "# All zeros\n",
    "c = torch.zeros([1,2,3])\n",
    "print(c)\n",
    "\n",
    "# 4x4 identity matrix\n",
    "d = torch.eye(4)\n",
    "print(d)\n",
    "\n",
    "# Uniform: scale and shift the torch.rand samples (*2 - 1) to mirror the\n",
    "# minval=-1 / maxval=1 call in the TensorFlow cell\n",
    "e = torch.rand([3,3]) *2 -1\n",
    "print(e)\n",
    "\n",
    "# Normal: allocate first, then fill in place with normal samples\n",
    "f = torch.empty(3,3)\n",
    "f = f.normal_()\n",
    "print(f)\n",
    "\n",
    "# Normal, alternative: sample with NumPy and convert (f is reused)\n",
    "f = torch.tensor(np.random.normal( size= [3,3]))\n",
    "print(f)\n",
    "\n",
    "# Uniform samples with the same shape as e\n",
    "g = torch.rand_like(e)\n",
    "print(g)\n",
    "\n",
    "# Constant fill: every element is 3\n",
    "h = torch.full([3,3],fill_value = 3)\n",
    "print(h)\n",
    "print(h.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_style": "center"
   },
   "source": [
    "### Indexing and slicing\n",
    "\n",
    "`start:end:step` (`end` is not included)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor([0.16387904 0.4265666  0.5988269 ], shape=(3,), dtype=float32)\n",
      "tf.Tensor([0.16387904 0.4265666  0.5988269 ], shape=(3,), dtype=float32)\n",
      "tf.Tensor([0.16387904 0.4265666  0.5988269 ], shape=(3,), dtype=float32)\n",
      "(28,)\n",
      "(14, 14)\n",
      "(28, 28)\n"
     ]
    }
   ],
   "source": [
    "# Random batch of shape [2, 28, 28, 3]\n",
    "x = tf.random.uniform([2,28,28,3])\n",
    "\n",
    "# Chained, comma-separated and explicit-':' indexing select the same 3 values\n",
    "print(x[0][0][0])\n",
    "print(x[0,0,0])\n",
    "print(x[0,0,0,:])\n",
    "\n",
    "# ':' keeps a whole axis -> shape (28,)\n",
    "print(x[0,:,0,0].shape)\n",
    "\n",
    "# '::2' takes every second index, so 28 becomes 14 on both axes\n",
    "print(x[1,::2,::2,1].shape)\n",
    "\n",
    "# '...' stands for all axes in between -> shape (28, 28)\n",
    "print(x[0,...,0].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "cell_style": "split",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.6419, 0.2041, 0.1747, 0.9867, 0.2958, 0.4511, 0.1343, 0.1190, 0.1462,\n",
      "        0.6560, 0.7630, 0.1408, 0.6491, 0.4121, 0.7311, 0.5382, 0.1502, 0.2961,\n",
      "        0.7076, 0.5014, 0.8862, 0.9230, 0.0037, 0.2196, 0.5322, 0.4499, 0.2151,\n",
      "        0.9921])\n",
      "tensor([0.6419, 0.2041, 0.1747, 0.9867, 0.2958, 0.4511, 0.1343, 0.1190, 0.1462,\n",
      "        0.6560, 0.7630, 0.1408, 0.6491, 0.4121, 0.7311, 0.5382, 0.1502, 0.2961,\n",
      "        0.7076, 0.5014, 0.8862, 0.9230, 0.0037, 0.2196, 0.5322, 0.4499, 0.2151,\n",
      "        0.9921])\n",
      "torch.Size([28])\n",
      "torch.Size([14, 14])\n",
      "torch.Size([3, 28])\n"
     ]
    }
   ],
   "source": [
    "# Random batch of shape [2, 3, 28, 28]\n",
    "x = torch.rand([2,3,28,28])\n",
    "# Chained and comma-separated indexing select the same row of 28 values\n",
    "print(x[0][0][0])\n",
    "print(x[0,0,0])\n",
    "print(x[0,0,0,:].shape)\n",
    "\n",
    "# '::2' takes every second index, so 28 becomes 14 on both axes\n",
    "print(x[0,1,::2,::2].shape)\n",
    "\n",
    "# '...' stands for all axes in between -> shape [3, 28]\n",
    "print(x[0,...,0].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reshape\n",
    "\n",
    "Image shape: [b, w, h, c] or [b, c, w, h]\n",
    "\n",
    "Reshape, Squeeze, Unsqueeze, Transpose"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5, 784, 3)\n",
      "(5, 14, 56, 3)\n",
      "(5, 28, 28, 3)\n",
      "(1, 5, 28, 28, 3)\n",
      "y shape:  (5, 28, 1, 28, 3)\n",
      "(5, 28, 28, 3)\n",
      "z shape (5, 3, 28, 28)\n",
      "(5, 28, 28, 3)\n"
     ]
    }
   ],
   "source": [
    "# Batch laid out as [b, w, h, c]\n",
    "x = tf.random.uniform([5,28,28,3])\n",
    "\n",
    "# reshape keeps the element count: 28*28 = 784 = 14*56\n",
    "print(tf.reshape(x, [5, 28*28, 3]).shape)\n",
    "print(tf.reshape(x, [5, 14, 56, 3]).shape)\n",
    "\n",
    "# reshape returned new tensors; x still has its original shape\n",
    "print(x.shape)\n",
    "\n",
    "# expand_dims inserts a size-1 axis at the given position\n",
    "print(tf.expand_dims(x, axis=0).shape)\n",
    "\n",
    "# squeeze removes the size-1 axis that expand_dims added\n",
    "y = tf.expand_dims(x, axis=2)\n",
    "print('y shape: ',y.shape)\n",
    "print(tf.squeeze(y, axis=2).shape)\n",
    "\n",
    "# Change dim [b,w,h,c] -> [b,c,w,h]\n",
    "\n",
    "z = tf.transpose(x,perm=[0,3,1,2])\n",
    "print('z shape', z.shape)\n",
    "\n",
    "# transpose also leaves x itself unchanged\n",
    "print(x.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "cell_style": "split",
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([5, 3, 784])\n",
      "torch.Size([5, 3, 784])\n",
      "torch.Size([5, 3, 784])\n",
      "torch.Size([5, 3, 14, 56])\n",
      "y shape:  torch.Size([5, 3, 1, 28, 28])\n",
      "torch.Size([5, 3, 28, 28])\n",
      "z shape:  torch.Size([5, 28, 28, 3])\n",
      "torch.Size([5, 28, 3, 28])\n",
      "torch.Size([5, 28, 3, 28])\n",
      "torch.Size([5, 3, 28, 28])\n"
     ]
    }
   ],
   "source": [
    "x = torch.rand([5,3,28,28])\n",
    "\n",
    "print(x.view(5,3,28*28).shape)\n",
    "print(x.reshape(5,3,28*28).shape)\n",
    "print(torch.reshape(x,[5,3,28*28]).shape)\n",
    "\n",
    "print(x.reshape(5,3,14,56).shape)\n",
    "\n",
    "y = x.unsqueeze(2)\n",
    "print('y shape: ', y.shape)\n",
    "\n",
    "print(y.squeeze(2).shape)\n",
    "\n",
    "# exchange dim [b, c, w, h] -> [b, w, h, c]\n",
    "z = x.transpose(1, 2).transpose(2, 3)\n",
    "print('z shape: ', z.shape)\n",
    "\n",
    "print(torch.transpose(x, 1,2).shape)\n",
    "print(x.permute(0,3,1,2).shape)\n",
    "\n",
    "print(x.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Broadcasting\n",
    "\n",
    "Manually broadcast with `tf.broadcast_to` / `y.expand(*size)`.\n",
    "\n",
    "Automatic broadcast in calculation.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4, 28, 28, 3)\n",
      "(4, 28, 28, 3)\n"
     ]
    }
   ],
   "source": [
    "# y has size 1 on every axis except axis 1, where its 28 matches x\n",
    "x = tf.random.uniform([4,28,28,3])\n",
    "y = tf.random.uniform([1,28,1,1])\n",
    "\n",
    "# Explicit broadcast to the shape of x\n",
    "print((tf.broadcast_to(y,[4,28,28,3])).shape)\n",
    "\n",
    "# Implicit broadcast during arithmetic gives the same shape\n",
    "print((x+y).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 3, 28, 28])\n",
      "torch.Size([4, 3, 28, 28])\n"
     ]
    }
   ],
   "source": [
    "# y has size 1 on every axis except axis 2, where its 28 matches x\n",
    "x = torch.rand([4,3,28,28])\n",
    "y = torch.rand([1,1,28,1])\n",
    "\n",
    "# Explicit broadcast to the shape of x\n",
    "print(y.expand(4,3,28,28).shape)\n",
    "\n",
    "# Implicit broadcast during arithmetic gives the same shape\n",
    "print((x+y).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4, 28, 28, 3)\n"
     ]
    }
   ],
   "source": [
    "# Here the leading 4 and 28 of y already match x; only its two trailing\n",
    "# size-1 axes need to be stretched.\n",
    "x = tf.random.uniform([4,28,28,3])\n",
    "y = tf.random.uniform([4,28,1,1])\n",
    "\n",
    "print((x + y).shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Merge and Split"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_style": "split"
   },
   "source": [
    "* tf.concat \n",
    "* tf.split\n",
    "* tf.stack\n",
    "* tf.unstack"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "cell_style": "split"
   },
   "source": [
    "* torch.cat\n",
    "* torch.stack\n",
    "* torch.split\n",
    "* torch.chunk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(9, 32, 8)\n",
      "(4, 32, 8)\n",
      "(4, 32, 8)\n",
      "(3, 32, 8) (6, 32, 8)\n",
      "(3, 32, 8) (3, 32, 8) (3, 32, 8)\n",
      "(32, 8) (32, 8) (32, 8)\n",
      "(3, 32, 8)\n",
      "(3, 3, 32, 8)\n",
      "(2, 4, 32, 8)\n"
     ]
    }
   ],
   "source": [
    "# Example layout: [classes, students, scores]\n",
    "# Concat dim = 0: more classes (4 + 5 = 9); the other axes must match\n",
    "\n",
    "a = tf.random.uniform([4,32,8])\n",
    "b = tf.random.uniform([5,32,8])\n",
    "print(tf.concat([a,b], axis=0).shape)\n",
    "\n",
    "# Concat dim = 1: same classes, different students (10 + 22 = 32)\n",
    "a = tf.random.uniform([4,10,8])\n",
    "b = tf.random.uniform([4,22,8])\n",
    "print(tf.concat([a,b], axis=1).shape)\n",
    "\n",
    "# Concat dim = 2: same classes and students, different subjects (5 + 3 = 8)\n",
    "a = tf.random.uniform([4,32,5])\n",
    "b = tf.random.uniform([4,32,3])\n",
    "print(tf.concat([a,b], axis=2).shape)\n",
    "\n",
    "c = tf.random.uniform([9,32,8])\n",
    "\n",
    "# Split with a list: explicit piece sizes along axis 0 (3 + 6 = 9)\n",
    "c1, c2 = tf.split(c,[3,6],axis=0)\n",
    "print(c1.shape, c2.shape)\n",
    "\n",
    "\n",
    "# Split with an integer: three pieces along axis 0 (each 3 rows, see output)\n",
    "c1, c2, c3 = tf.split(c,3,axis=0)\n",
    "print(c1.shape, c2.shape, c3.shape)\n",
    "\n",
    "# unstack removes axis 0 and returns one tensor per index along it\n",
    "d1, d2, d3 = tf.unstack(c1,axis=0)\n",
    "print(d1.shape, d2.shape, d3.shape)\n",
    "\n",
    "# stack adds a NEW axis: three (32, 8) tensors -> (3, 32, 8)\n",
    "print(tf.stack([d1,d2,d3], axis=0).shape)\n",
    "print(tf.stack([c1,c2,c3], axis=0).shape)\n",
    "\n",
    "# Example [classes, students, scores] -> [grades, classes, students, scores]\n",
    "a = tf.random.uniform([4,32,8])\n",
    "b = tf.random.uniform([4,32,8])\n",
    "print(tf.stack([a,b], axis=0).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "cell_style": "split"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([9, 32, 8])\n",
      "torch.Size([4, 32, 8])\n",
      "torch.Size([4, 32, 8])\n",
      "torch.Size([3, 32, 8]) torch.Size([6, 32, 8])\n",
      "torch.Size([3, 32, 8]) torch.Size([3, 32, 8]) torch.Size([3, 32, 8])\n",
      "torch.Size([3, 32, 8]) torch.Size([3, 32, 8]) torch.Size([3, 32, 8])\n",
      "torch.Size([2, 4, 32, 8])\n"
     ]
    }
   ],
   "source": [
    "# Example layout: [classes, students, scores]\n",
    "# Concat dim = 0: more classes (4 + 5 = 9); the other axes must match\n",
    "\n",
    "a = torch.rand(4,32,8)\n",
    "b = torch.rand(5,32,8)\n",
    "print(torch.cat([a,b], dim=0).shape)\n",
    "\n",
    "# Concat dim = 1: same classes, different students (10 + 22 = 32)\n",
    "a = torch.rand(4,10,8)\n",
    "b = torch.rand(4,22,8)\n",
    "print(torch.cat([a,b], dim=1).shape)\n",
    "\n",
    "# Concat dim = 2: same classes and students, different subjects (5 + 3 = 8)\n",
    "a = torch.rand(4,32,5)\n",
    "b = torch.rand(4,32,3)\n",
    "print(torch.cat([a,b], dim=2).shape)\n",
    "\n",
    "\n",
    "c = torch.rand(9,32,8)\n",
    "# Split with a list: explicit piece sizes along dim 0 (3 + 6 = 9)\n",
    "# Function form: c1, c2 = torch.split(c,[3,6],dim=0)\n",
    "c1, c2 = c.split([3,6],dim=0)\n",
    "print(c1.shape, c2.shape)\n",
    "\n",
    "# Split with an integer: the integer is the SIZE of each piece\n",
    "# Function form: c1, c2, c3 = torch.split(c,3,dim=0)\n",
    "c1, c2, c3 = c.split(3,dim=0)\n",
    "print(c1.shape, c2.shape, c3.shape)\n",
    "\n",
    "# chunk: the integer is the NUMBER of pieces; same result here since 9 / 3 == 3\n",
    "# Function form: c1, c2, c3 = torch.chunk(c,3,dim=0)\n",
    "c1, c2, c3 = c.chunk(3,dim=0)\n",
    "print(c1.shape, c2.shape, c3.shape)\n",
    "\n",
    "\n",
    "# Example [classes, students, scores] -> [grades, classes, students, scores]\n",
    "a = torch.rand(4,32,8)\n",
    "b = torch.rand(4,32,8)\n",
    "print(torch.stack([a,b],dim=0).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
